intro
To illustrate the regression case, we’ll use the Hitters
data from the ISLR2 package. A
sample of the data is shown below.
data("Hitters", package = "ISLR2")
# Remove rows with missing response values
hitters <- Hitters[!is.na(Hitters$Salary), ]
head(hitters)
#>                   AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun
#> -Alan Ashby         315   81     7   24  38    39    14   3449   835     69
#> -Alvin Davis        479  130    18   66  72    76     3   1624   457     63
#> -Andre Dawson       496  141    20   65  78    37    11   5628  1575    225
#> -Andres Galarraga   321   87    10   39  42    30     2    396   101     12
#> -Alfredo Griffin    594  169     4   74  51    35    11   4408  1133     19
#> -Al Newman          185   37     1   23   8    21     2    214    42      1
#>                   CRuns CRBI CWalks League Division PutOuts Assists Errors
#> -Alan Ashby         321  414    375      N        W     632      43     10
#> -Alvin Davis        224  266    263      A        W     880      82     14
#> -Andre Dawson       828  838    354      N        E     200      11      3
#> -Andres Galarraga    48   46     33      N        E     805      40      4
#> -Alfredo Griffin    501  336    194      A        W     282     421     25
#> -Al Newman           30    9     24      N        E      76     127      7
#>                   Salary NewLeague
#> -Alan Ashby        475.0         N
#> -Alvin Davis       480.0         A
#> -Andre Dawson      500.0         N
#> -Andres Galarraga   91.5         N
#> -Alfredo Griffin   750.0         A
#> -Al Newman          70.0         A

We’ll start by fitting a basic EBMRegressor to the
hitters data set. Note that the ebm() function
currently only supports the usual formula interface.
library(ebm)
# Fit a default EBM regressor
fit <- ebm(Salary ~ ., data = hitters, objective = "rmse")
fit # still need to implement print() and summary() methods
#> ExplainableBoostingRegressor(early_stopping_tolerance=0)

You can obtain predictions using the familiar predict()
method employed by most modeling packages in R. Note that through
bagging, EBMs can provide standard errors for the predictions if
requested.
head(predict(fit, newdata = hitters))
#> [1] 489.4548 626.1970 870.3430 169.8797 659.6543 270.5382
head(predict(fit, newdata = hitters, se.fit = TRUE))
#>          [,1]      [,2]
#> [1,] 489.4548  53.92471
#> [2,] 626.1970  98.39167
#> [3,] 870.3430 241.04199
#> [4,] 169.8797  43.14385
#> [5,] 659.6543  54.02381
#> [6,] 270.5382 155.71903

You can produce several plotly-based graphs to help interpret
the output of "EBM" objects using the generic
plot() method; this function supports both global and local
interpretations. The default simply prints a global measure of
importance based on the sum of the absolute values of each variable's term
contributions. (For Markdown-type documents, like this vignette, you
need to specify display = "markdown"; see
?ebm::plot for details.)
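To make the importance measure concrete, here is a minimal base-R sketch that does not call the ebm package at all. It assumes a hypothetical matrix `contrib` of per-observation term contributions (rows are observations, columns are terms) filled with made-up values, and ranks the terms by the sum of their absolute contributions, following the description above.

```r
# Hypothetical term contributions (made-up values, NOT extracted from `fit`):
# one row per observation, one column per term.
contrib <- matrix(
  c( 120, -40,   5,
     -80,  60, -10,
     200, -20,  15),
  nrow = 3, byrow = TRUE,
  dimnames = list(NULL, c("Years", "Hits", "Walks"))
)

# Importance of each term: sum of the absolute values of its contributions
importance <- colSums(abs(contrib))

# Rank terms from most to least important
ranked <- sort(importance, decreasing = TRUE)
ranked
```

Dividing each sum by the number of rows would give a mean absolute contribution instead; the ranking of the terms stays the same either way.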
You can also plot the individual shape functions (or term contributions), as shown below:
While the ebm package does not expose 100% of the
functionality available in Python, you can pretty much do anything you
need by interacting directly with the underlying Python objects (the
magic happens through reticulate). For
instance, we can invoke the monotonize() method to
post-process the term for "Years" by enforcing increasing
monotonicity:
fit$monotonize("Years", increasing = TRUE)
#> ExplainableBoostingRegressor(early_stopping_tolerance=0)
plot(fit, term = "Years", display = "markdown")

You can also display local explanations (though only one at a time) by
specifying local = TRUE:
# Understand an individual prediction
x <- subset(hitters, select = -Salary)[1L, ] # use first observation
plot(fit, local = TRUE, X = x, y = hitters$Salary[1L], display = "markdown")
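
Because an EBM is an additive model, a local explanation amounts to splitting a single prediction into an intercept plus one contribution per term. The base-R sketch below illustrates that decomposition with made-up numbers; `intercept` and `contrib` are hypothetical values chosen for illustration, not quantities extracted from `fit`.

```r
# Hypothetical pieces of one prediction (made-up values, NOT taken from `fit`)
intercept <- 535.9
contrib <- c(Years = -60.2, Hits = 25.4, Walks = -11.6)

# For an additive model, the prediction is the intercept plus the sum of
# the term contributions for this observation
prediction <- intercept + sum(contrib)

# Order the terms by the size of their effect on this one prediction,
# which is the ordering a local explanation plot would typically emphasize
ranked <- contrib[order(abs(contrib), decreasing = TRUE)]
ranked
```

The sign of each contribution shows whether that term pushed this particular prediction above or below the intercept.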